# import required packages
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import mean_squared_error

# read the data
# df = pd.read_excel("UCI_air/AirQualityUCI.xlsx", parse_dates=[['Date', 'Time']])
#
# #check the dtypes
# print(df.dtypes)
#
# df['Date_Time'] = pd.to_datetime(df.Date_Time , format = '%d/%m/%Y %H.%M.%S')
# data = df.drop(['Date_Time'], axis=1)
# data.index = df.Date_Time

# Load the source table (GBK-encoded CSV) and discard its leading column.
source_table = pd.read_csv("../data/econ/表2.csv", encoding='gbk')
first_label = source_table.columns[0]
data = source_table.drop(columns=[first_label])

# Disabled steps retained from earlier experiments (explicit index-column
# drop and replacement of -200 entries with the previous row's value):
# data = df.drop(['Unnamed: 0'], axis=1)
# #missing value treatment
# cols = data.columns
# for j in cols:
#     for i in range(0,len(data)):
#        if data[j][i] == -200:
#            data[j][i] = data[j][i-1]

# Show which columns remain after dropping the first one.
print(data.columns)

# checking stationarity
from statsmodels.tsa.vector_ar.vecm import coint_johansen

# since the test works for only 12 variables, I have randomly dropped one column;
# in the next iteration, I would drop another and check the eigenvalues
# johan_test_temp = data.drop([ 'CO(GT)'], axis=1)

# print(coint_johansen(johan_test_temp,-1,1).eig)

# creating the train and validation set
# The first 80% of the rows (in file order) are used for fitting and the
# remaining 20% are held out; the split point is computed once so both
# slices are guaranteed to meet at the same row.
split_at = int(0.8 * len(data))
train = data[:split_at]
valid = data[split_at:]

# fit the model
from statsmodels.tsa.vector_ar.var_model import VAR

# VAR is fitted on the training rows with the library's default settings.
model = VAR(endog=train)
model_fit = model.fit()

# make prediction on validation
# `endog` holds the observations the model was fitted on. The older `y`
# alias is deprecated and missing from recent statsmodels releases, so the
# supported attribute is used to seed the forecast.
prediction = model_fit.forecast(model_fit.endog, steps=len(valid))

#
# #converting predictions to dataframe
# pred = pd.DataFrame(index=range(0,len(prediction)),columns=[cols])
# for j in range(0,13):
#     for i in range(0, len(prediction)):
#        pred.iloc[i][j] = prediction[i][j]
# #check rmse
# for i in range(len(cols)):
#     print('rmse value for', cols[i], 'is : ', np.sqrt(mean_squared_error(pred.iloc[i], valid.iloc[i])))

# make final predictions

# Number of future periods to forecast.
# (An earlier configuration used steps = 3 with a 17-row table.)
steps = 90

# Refit on the complete data set before forecasting beyond its end.
model = VAR(endog=data)
model_fit = model.fit()

# `endog` is the array of observations the model was fitted on; the older
# `y` alias is deprecated and absent from recent statsmodels releases.
history = np.asarray(model_fit.endog)

# Derived from the fitted data instead of being hard-coded (was 183), so the
# x-axis ranges and the exported history always match the actual row count.
num_row = history.shape[0]

yhat = model_fit.forecast(history, steps=steps)

# The first column of `data` is excluded from the per-column output.
# `history` and `yhat` still contain every column of `data`, so enumeration
# starts at 1 to pair each name with its own series (previously each name
# was paired with the series of the column before it).
new_columns = data.columns[1:]
series_by_column = {}
for col_idx, column_name in enumerate(new_columns, start=1):
    column_prev = history[:, col_idx]  # observed values
    column_pred = yhat[:, col_idx]     # forecast values
    series_by_column[column_name] = np.concatenate([column_prev, column_pred])

    # One figure per column: history followed by the forecast segment.
    plt.plot(range(num_row), column_prev)
    plt.plot(range(num_row, num_row + steps), column_pred)
    plt.title(column_name)
    plt.show()

# Each exported column is the observed history followed by the forecast.
df_ult = pd.DataFrame(series_by_column)
df_ult.to_csv("../data/econ/var_result2.csv")
print("saved")

# Overview figure with every series (history + forecast) on one set of axes.
plt.figure(figsize=(30, 18))
plt.plot(range(num_row), history)
plt.plot(range(num_row, num_row + steps), yhat)
plt.show()

# Forecast segment on its own.
plt.plot(yhat)
plt.show()
print(history.shape)
print(yhat)

# def ARIMA_FILL(data,columns):
#     total_result = []
#     for i in range(len(columns)):
#         if columns[i] == "time":
#             total_result.append(np.array[202010,202011,202012,202101,202202,202203])
#         elif columns[i] == "_record_id_":
#             total_result.append(np.arange(num_row+1,num_row+1+forecast_step))
#         else:
#             data_value = np.asarray(data[[columns[i]]])
#             model = ARIMA(data_value, (1, 1, 0)).fit()
#             result = model.forecast(forecast_step)
#             total_result.append(result[0])
#     a = (pd.DataFrame(total_result)).T
#     a.columns = columns
#     final_result = data.append(a, ignore_index=True)
#     return final_result
#
# df = ARIMA_FILL(data,data.columns)
#
# df = ARIMA_FILL(data,data.columns)
#
# def ARIMA_FILL(data):
#     total_result = []
# 	columns = data.columns
#     for i in range(len(columns)):
#         if columns[i] == "time":
#             total_result.append(np.array([202010,202011,202012,202101,202102,202103])) #[todo]
#         elif columns[i] == "_record_id_":
#             total_result.append(np.arange(num_row+1,num_row+1+forecast_step))
#         else:
#             data_value = np.asarray(data[[columns[i]]])
#             model = ARIMA(data_value, (1, 1, 0)).fit()
#             result = model.forecast(forecast_step)
#             total_result.append(result[0])
#     a = (pd.DataFrame(total_result)).T
#     a.columns = columns
#     final_result = data.append(a, ignore_index=True)
#     return final_result
